Skip to content

[KSM] support keep sampling mask#7460

Open
zeroRains wants to merge 8 commits intoPaddlePaddle:release/2.6from
zeroRains:ksm_2.6
Open

[KSM] support keep sampling mask#7460
zeroRains wants to merge 8 commits intoPaddlePaddle:release/2.6from
zeroRains:ksm_2.6

Conversation

@zeroRains
Copy link
Copy Markdown
Contributor

@zeroRains zeroRains commented Apr 17, 2026

Motivation

本 PR 为 FastDeploy 实现 Keep Sampling Mask (KSM) 功能,用于在 top_p/top_k 采样过程中返回保留的词汇表索引列表(稀疏格式)。

当前推理引擎在执行 top_p/top_k 采样时,仅返回最终采样的 token ID,但不提供采样过程中的候选集合信息。这导致:

  1. 可解释性不足:无法了解模型在每个 token 生成时考虑了哪些候选词
  2. 调试困难:难以分析采样策略(如 top_p=0.9, top_k=50)的实际效果
  3. 下游应用受限:无法基于候选集合实现自定义后处理逻辑logprobs
  4. 归一化不完整:返回的 logprobs 未基于截断后的候选集合重新归一化

本 PR 通过新增 sampling_mask 字段,记录每个 token 采样时保留的词汇表索引(稀疏格式),并提供基于候选集合的 logprobs 重归一化功能。

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

sampler.py 下新增_compute_sampling_mask方法
添加启动参数--enable-keep-sampling-mask
logprobs.py 下新增logz的renormalize函数,logprobs_renormalize_with_logz
pre_and_post_process.py的post_processs中调用renormalize函数

Usage or Command

服务启动指令:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
MODEL_PATH="/root/paddlejob/tmpspace/GLM-4.5-Air/"
python -m fastdeploy.entrypoints.openai.api_server \
    --port 9293 \
    --host $(hostname -i) \
    --model "$MODEL_PATH" \
    --disable-custom-all-reduce \
    --tensor-parallel-size 8 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --gpu-memory-utilization 0.9 \
    --graph-optimization-config '{"use_cudagraph":true}' \
    --enable-logprob \
    --enable-keep-sampling-mask \
    --speculative-config '{"method":"mtp","num_speculative_tokens":1,"num_model_steps":1,"model":"'$MODEL_PATH'"}'

Accuracy Tests

yes

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 17, 2026 06:02
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Apr 17, 2026

Thanks for your contribution!

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

该 PR 旨在为推理服务增加 keep sampling mask 输出能力:在 top_p/top_k 截断采样后,将每步保留下来的词表索引以稀疏形式返回/流式返回,便于客户端侧做可解释性与调试分析,并补充相应的 CLI 开关与端到端测试。

Changes:

  • 新增启动参数 --enable-keep-sampling-mask,贯通 Engine/Worker/Sampler/TokenProcessor/OpenAI Serving 的开关传递。
  • 在采样阶段计算稀疏 sampling_mask(以及 logZ),并在非 FD_USE_GET_SAVE_OUTPUT_V1 路径通过 ZMQ side-channel 发送到 token_processor,再输出到 OpenAI 响应。
  • 新增/更新单测与 e2e 测试覆盖 sampling_mask 在流式与非流式响应中的格式与一致性。

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/output/test_process_batch_output.py 为测试构造的 processor 补齐 use_sampling_mask 字段初始化。
tests/entrypoints/openai/test_max_streaming_tokens.py 更新调用以适配 chat choice 新增的 sampling_mask_list 参数。
tests/e2e/test_ernie_21b_mtp.py e2e:启动参数开启 keep sampling mask,并新增流式/非流式/不同 top_p 的校验用例。
fastdeploy/worker/worker_process.py Worker CLI 新增 --enable-keep-sampling-mask(含下划线与短横线别名)。
fastdeploy/worker/output.py SamplerOutput 新增 sampling_mask 与 logz_per_batch 字段(稀疏 mask 与 logZ)。
fastdeploy/worker/gpu_model_runner.py 读取配置开关;非 V1 路径创建 sampling_mask ZMQ client;prepare_inputs 传 keep_sampling_mask;save_output 透传 sampling_mask_zmq_client。
fastdeploy/output/token_processor.py 非 V1 路径新增 sampling_mask ZMQ server;每步接收 mask 并写入 RequestOutput.outputs。
fastdeploy/output/stream_transfer_data.py StreamTransferData 新增 sampling_mask 字段以承载稀疏 mask。
fastdeploy/model_executor/pre_and_post_process.py stream transfer data 增加 sampling_mask;save_output_* 增加 side-channel 发送;新增基于 logZ 的 logprobs 归一化步骤。
fastdeploy/model_executor/layers/sample/sampler.py 新增 _compute_sampling_mask;normal 与 speculative 路径在采样前计算 sampling_mask/logZ 并写入 SamplerOutput。
fastdeploy/model_executor/layers/sample/meta_data.py SamplingMetadata 新增 keep_sampling_mask 字段。
fastdeploy/model_executor/layers/sample/logprobs.py build_output_logprobs 返回值新增 output_logits;新增 logprobs_renormalize_with_logz。
fastdeploy/entrypoints/openai/serving_chat.py 在 stream/full 响应中输出 sampling_mask;新增 _make_sampling_mask_list 并在 choice 汇总时扁平化。
fastdeploy/entrypoints/openai/protocol.py OpenAI 协议响应模型新增 sampling_mask 字段(List[List[int]])。
fastdeploy/engine/request.py CompletionOutput 新增 sampling_mask 字段并纳入 to_dict 输出。
fastdeploy/engine/engine.py worker_store_true_flag 增加 enable_keep_sampling_mask,启动 worker 时透传开关。
fastdeploy/engine/common_engine.py 同 engine.py:透传 enable_keep_sampling_mask 到 worker 启动参数。
fastdeploy/engine/args_utils.py EngineArgs/CLI 新增 --enable-keep-sampling-mask 参数与说明。
fastdeploy/config.py ModelConfig 新增 enable_keep_sampling_mask 默认字段。

Comment thread fastdeploy/entrypoints/openai/serving_chat.py
Comment thread fastdeploy/model_executor/layers/sample/logprobs.py
Comment thread fastdeploy/worker/output.py
Comment thread fastdeploy/output/stream_transfer_data.py
Comment thread fastdeploy/model_executor/layers/sample/sampler.py Outdated
Comment thread fastdeploy/output/token_processor.py
Comment thread fastdeploy/model_executor/pre_and_post_process.py
Comment thread fastdeploy/model_executor/pre_and_post_process.py
PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings April 17, 2026 07:30
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 19 out of 19 changed files in this pull request and generated 5 comments.

Comment on lines +625 to +638
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
n = max(int(n), 0)
if n > 0:
# List of n sparse index arrays, one per accepted token
mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speculative 路径发送 sampling_mask 时用的是 model_output.accept_num 来做分组并按 i 构造 mask_dict key,但上面 speculate_save_output(_topk) 的输出会经过 index_to_batch_id + enable_pd_reorder 恢复到原始 batch 顺序;如果开启 PD reorder,这里未对 sampler_output.sampling_mask / accept_num / logz_per_batch 做一致的恢复排序,mask_dict 的 key/分组将与 token_processor 侧的 batch_id 不一致。建议:在生成 mask_dict 前先对 accept_num 与 sampling_mask 做与输出一致的 recover/reorder(可复用 recover_share_inputs["accept_num_cpu"] 或扩展 recover_batch_index_for_sampler_output),并同步重排 logz_per_batch。

Suggested change
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
n = max(int(n), 0)
if n > 0:
# List of n sparse index arrays, one per accepted token
mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
# Recover it to the same batch order as speculate_save_output(_topk) before grouping by request.
real_bsz = recover_share_inputs["accept_num_cpu"].shape[0]
raw_accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
recovered_accept_nums = recover_share_inputs["accept_num_cpu"][:real_bsz].flatten().tolist()
total_masks = len(sampler_output.sampling_mask)
sampling_mask_groups = []
offset = 0
for n in raw_accept_nums:
n = max(int(n), 0)
sampling_mask_groups.append(sampler_output.sampling_mask[offset : offset + n])
offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
recovered_sampling_mask_groups = [[] for _ in range(real_bsz)]
if model_output.index_to_batch_id is None:
batch_id_map = list(range(real_bsz))
else:
batch_id_map = np.asarray(model_output.index_to_batch_id[:real_bsz]).flatten().tolist()
for i, group in enumerate(sampling_mask_groups):
batch_id = int(batch_id_map[i])
if batch_id < 0 or batch_id >= real_bsz:
raise ValueError(f"sampling_mask batch_id out of range: {batch_id}, real_bsz={real_bsz}")
recovered_sampling_mask_groups[batch_id] = group
mask_dict = {}
for i, n in enumerate(recovered_accept_nums):
n = max(int(n), 0)
if len(recovered_sampling_mask_groups[i]) != n:
raise ValueError(
f"sampling_mask group size mismatch for batch {i}: "
f"expected {n}, got {len(recovered_sampling_mask_groups[i])}"
)
if n > 0:
# List of n sparse index arrays, one per accepted token.
mask_dict[i] = [arr.tolist() for arr in recovered_sampling_mask_groups[i]]

Copilot uses AI. Check for mistakes.
Comment thread fastdeploy/model_executor/layers/sample/sampler.py
Comment on lines +234 to +246
logz = paddle.to_tensor(logz, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
)
# Update logprobs_tensors with normalized values
return LogprobsTensors(
logprob_token_ids=logprobs_tensors.logprob_token_ids,
logprobs=normalized_logprobs,
selected_token_ranks=logprobs_tensors.selected_token_ranks,
)
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logprobs_renormalize_with_logz 目前对所有 isfinite 的位置统一做 logprobs - logZ_K,但 logprobs_tensors 里的 top-k 项是从“全量分布”topk 取出的,未必全部落在 top_p/top_k 截断后的候选集合 K 内(尤其当 top_p 很小且 max_logprobs 较大时)。这会导致返回的“重归一化 logprobs”仍包含候选集之外 token 的有限值,不符合截断分布语义。建议结合 sampling_mask(或 candidate set)把不在 K 内的 token logprobs 置为 -inf / None,并仅对 K 内条目做重归一化,或改为直接在截断后的分布上构造 logprobs 输出。

Copilot uses AI. Check for mistakes.
Comment thread tests/e2e/test_ernie_21b_mtp.py
Comment thread fastdeploy/engine/args_utils.py
PaddlePaddle-bot

This comment was marked as outdated.

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Apr 17, 2026

Codecov Report

❌ Patch coverage is 77.53623% with 31 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.6@f4f7760). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/pre_and_post_process.py 57.69% 6 Missing and 5 partials ⚠️
fastdeploy/entrypoints/openai/serving_chat.py 38.46% 6 Missing and 2 partials ⚠️
fastdeploy/output/token_processor.py 60.00% 4 Missing and 2 partials ⚠️
...astdeploy/model_executor/layers/sample/logprobs.py 55.55% 4 Missing ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 96.22% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.6    #7460   +/-   ##
==============================================
  Coverage               ?   73.87%           
==============================================
  Files                  ?      376           
  Lines                  ?    53130           
  Branches               ?     8300           
==============================================
  Hits                   ?    39250           
  Misses                 ?    11129           
  Partials               ?     2751           
Flag Coverage Δ
GPU 73.87% <77.53%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings April 17, 2026 10:58
PaddlePaddle-bot

This comment was marked as outdated.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (1)

tests/output/test_process_batch_draft_tokens.py:39

  • 这里 cfg.model_config 是 MagicMock,若未显式设置 enable_keep_sampling_mask=False,TokenProcessor 可能把 keep_sampling_mask 当成开启并尝试创建/绑定 ZMQ IPC server(路径包含固定的 "9700"),在测试并发或重复执行时容易冲突。建议在 cfg.model_config 上补充 enable_keep_sampling_mask = False(除非本用例确实要覆盖该功能并做好 socket 清理/隔离)。
        # 模拟 cfg
        cfg = MagicMock()
        cfg.speculative_config = MagicMock()
        cfg.parallel_config.local_data_parallel_id = 0
        cfg.parallel_config.engine_worker_queue_port = ["9700"]
        cfg.speculative_config.method = "mtp"
        cfg.speculative_config.num_speculative_tokens = 3
        cfg.model_config = MagicMock()
        cfg.model_config.enable_logprob = True

Comment on lines +749 to +755
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if mask_data is not None and isinstance(mask_data, dict):
sampling_masks_per_request = mask_data

Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用 block=True 同步等待 sampling_mask side-channel 消息,缺少超时/降级路径:一旦 worker 未发送(例如 client 未创建/发送失败/某些 runner 未接入该 side-channel),TokenProcessor 会永久阻塞,导致整体推理挂死。建议改为非阻塞轮询(block=False)并在缺失时允许该 step 继续,或增加可配置超时并打印错误日志,避免死锁。

Suggested change
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if mask_data is not None and isinstance(mask_data, dict):
sampling_masks_per_request = mask_data
# where the value is a list[int] or list[list[int]] of allowed token ids.
# Use a non-blocking receive so a missing side-channel message does not
# stall the whole token processing loop.
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
mask_data = None
try:
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
except zmq.Again:
mask_data = None
except Exception:
llm_logger.exception(
"Failed to receive sampling_mask side-channel message; "
"continuing without sampling mask for this step."
)
mask_data = None
if mask_data is not None:
if isinstance(mask_data, dict):
sampling_masks_per_request = mask_data
else:
llm_logger.warning(
"Ignore invalid sampling_mask side-channel payload type: %s",
type(mask_data).__name__,
)

Copilot uses AI. Check for mistakes.
def setup_method(self):
self.mock_cfg = MagicMock()
self.mock_cfg.parallel_config.local_data_parallel_id = 0
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的 cfg/model_config 使用 MagicMock 时,TokenProcessor.initgetattr(cfg.model_config, "enable_keep_sampling_mask", False) 会返回一个 truthy 的 MagicMock,导致单测意外开启 keep_sampling_mask 并尝试 bind 固定的 IPC 地址(/dev/shm/sampling_mask_output_rank_0_9700.socket),容易在并行/重复运行时出现“Address already in use”或资源泄漏。建议在 mock_cfg.model_config 上显式设置 enable_keep_sampling_mask=False(或 patch envs.FD_USE_GET_SAVE_OUTPUT_V1=True 以避免创建该 server)。

Suggested change
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False

Copilot uses AI. Check for mistakes.
"""为 TokenProcessor 测试设置通用的 mock 对象。"""
self.mock_cfg = MagicMock()
self.mock_cfg.parallel_config.local_data_parallel_id = 0
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该 setUp 使用 MagicMock 构造 cfg 时同样存在 enable_keep_sampling_mask 被 MagicMock 误判为 True 的风险,TokenProcessor 可能在单测中意外创建并 bind sampling_mask 的 ZMQ IPC socket,造成端口/文件冲突和测试不稳定。建议显式设置 self.mock_cfg.model_config.enable_keep_sampling_mask = False。

Suggested change
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False

Copilot uses AI. Check for mistakes.
Comment on lines 29 to 35
def setUp(self):
self.cfg = MagicMock()
self.cfg.model_config.enable_logprob = True
self.cfg.speculative_config.method = None
self.cfg.parallel_config.local_data_parallel_id = 0
self.cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.cached_generated_tokens = MagicMock()
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该测试 cfg 通过 MagicMock 构造,TokenProcessor 初始化时可能将 enable_keep_sampling_mask 读取为 truthy 的 MagicMock,从而在单测里意外创建并 bind sampling_mask 的 ZMQ IPC server(固定 name/端口),导致用例间冲突或资源泄漏。建议在 cfg.model_config 上显式设置 enable_keep_sampling_mask=False。

Copilot uses AI. Check for mistakes.
Comment on lines +380 to +386
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
sampler_output.logprobs_tensors = logprobs_renormalize_with_logz(
sampler_output.logprobs_tensors.logprobs,
sampler_output.logz_per_batch,
sampler_output.logprobs_tensors,
)
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里对 logprobs 做 renormalize 时需要避免与 Sampler.compute_logprobs 中的 top_p_normalized_logprobs 逻辑重复归一化;否则当请求侧已开启 top_p_normalized_logprobs(top_p!=1.0)时会出现二次减去 logZ,导致返回的 logprobs 数值错误。建议按 request/token 维度判断是否已做过 top_p 归一化,再决定是否应用 logz_per_batch(或仅对未归一化的行应用)。

Copilot uses AI. Check for mistakes.
PaddlePaddle-bot

This comment was marked as outdated.

Copilot AI review requested due to automatic review settings April 20, 2026 04:59
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 6 comments.

Comments suppressed due to low confidence (1)

tests/output/test_process_batch_draft_tokens.py:39

  • 该用例里 cfg/model_config 使用 MagicMock,未显式设置 enable_keep_sampling_mask 会导致 TokenProcessor 在 setUp 时尝试创建并 bind sampling_mask 的 ZMQ IPC server(固定 socket 文件名),从而引入测试间冲突/资源泄漏风险。建议在这里把 cfg.model_config.enable_keep_sampling_mask 显式设为 False(或在 teardown 关闭 server)。
        # 模拟 cfg
        cfg = MagicMock()
        cfg.speculative_config = MagicMock()
        cfg.parallel_config.local_data_parallel_id = 0
        cfg.parallel_config.engine_worker_queue_port = ["9700"]
        cfg.speculative_config.method = "mtp"
        cfg.speculative_config.num_speculative_tokens = 3
        cfg.model_config = MagicMock()
        cfg.model_config.enable_logprob = True

Comment on lines +197 to +205
k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1]
# boundary_idx = last True position (k-1), clamp for safety
boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1]
boundary_prob = paddle.take_along_axis(
renorm_sorted_probs,
boundary_idx,
axis=-1,
) # [B, 1]
topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob)
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_compute_sampling_mask() 里 boundary_idx 由 bool sum 得到的是 int32,直接传给 paddle.take_along_axis 可能触发索引 dtype 不兼容(Paddle 通常要求 int64 索引),导致启用 keep_sampling_mask 时运行时报错。建议在 take_along_axis 前显式把 boundary_idx cast 到 int64。

Copilot uses AI. Check for mistakes.
Comment on lines +456 to +461
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

sampling_mask_zmq_client.send_pyobj(mask_dict)
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_output_normal() 在 enable_pd_reorder=True 时会调用 recover_batch_index_for_sampler_output(),但该函数当前不会重排 sampler_output.sampling_mask(以及 logz_per_batch)。这样会导致 sampling_mask 与 recover 后的 sampled_token_ids / batch_id 对不上,返回给客户端的 sampling_mask 可能错配到其他 request。建议在 recover 流程里把 sampling_mask/logz_per_batch 也按 index_to_batch_id 同步重排,或在发送 mask_dict 前基于 index_to_batch_id 做一次 list 重排。

Copilot uses AI. Check for mistakes.
Comment on lines +696 to +709
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
n = max(int(n), 0)
if n > 0:
# List of n sparse index arrays, one per accepted token
mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_output_specualate() 里 sampling_mask 的分组/发送同样没有考虑 enable_pd_reorder 的 index_to_batch_id 重排(recover_batch_index_for_sampler_output 也不会处理 sampling_mask)。在开启 PD reorder 时这里会出现 per-request 的 sampling_mask 分发错位。建议在构造 mask_dict 前先对 sampling_mask 与 accept_num 的对齐关系做恢复(要么扩展 recover_batch_index_for_sampler_output 支持 sampling_mask/logz_per_batch,要么在这里显式按 index_to_batch_id 重排后再按 accept_num 分组)。

Suggested change
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
n = max(int(n), 0)
if n > 0:
# List of n sparse index arrays, one per accepted token
mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
# It is flattened in sampler-output order, so when PD reorder is enabled we must
# first restore the per-request chunk order with index_to_batch_id before grouping.
real_bsz = model_output.accept_num.shape[0]
accept_nums_for_sampling_mask = model_output.accept_num[:real_bsz].flatten().tolist()
total_masks = len(sampler_output.sampling_mask)
restored_sampling_mask = [[] for _ in range(real_bsz)]
offset = 0
for sampler_idx, n in enumerate(accept_nums_for_sampling_mask):
n = max(int(n), 0)
next_offset = offset + n
mask_chunk = sampler_output.sampling_mask[offset:next_offset]
if len(mask_chunk) != n:
raise ValueError(
f"sampling_mask length mismatch while grouping: expected {n}, got {len(mask_chunk)} "
f"for sampler_idx {sampler_idx}"
)
if model_output.enable_pd_reorder:
batch_id = int(model_output.index_to_batch_id[sampler_idx])
else:
batch_id = sampler_idx
restored_sampling_mask[batch_id] = mask_chunk
offset = next_offset
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
mask_dict = {}
for batch_id, mask_chunk in enumerate(restored_sampling_mask):
if mask_chunk:
# List of sparse index arrays, one per accepted token for this request.
mask_dict[batch_id] = [arr.tolist() for arr in mask_chunk]

Copilot uses AI. Check for mistakes.
Comment on lines +456 to +461
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

sampling_mask_zmq_client.send_pyobj(mask_dict)
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在发送 sampling_mask 时未检查 sampling_mask_zmq_client 是否为 None。当前参数声明允许为 None,一旦上游配置/注入不一致(例如 keep_sampling_mask 被置真但 client 未初始化)会直接 AttributeError 并中断推理线程。建议在 send_pyobj 前加显式判空/断言,并给出更清晰的错误信息。

Copilot uses AI. Check for mistakes.
Comment on lines +196 to +198
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
# Shape: [num_reqs]
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SamplerOutput.logz_per_batch 的注释写的是“Shape: [num_reqs]”,但在 speculative decoding 路径里 logz_per_batch 实际是按 accepted token 展平计算的(shape 更接近 [total_accepted_tokens]),仅用于 logprobs 重归一化。建议更新注释/命名以反映两种路径的真实维度,避免后续误用。

Suggested change
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
# Shape: [num_reqs]
# logZ_K used for logprob renormalization:
# - Non-speculative decoding: per-request values with shape [num_reqs].
# - Speculative decoding: flattened per-accepted-token values with shape
# approximately [total_accepted_tokens].
# Callers MUST NOT assume this is always shaped by num_reqs; interpret the
# dimension according to the current decoding path.

Copilot uses AI. Check for mistakes.
Comment on lines +49 to +52
# 1-D int32 numpy array of vocab indices retained by top_p/top_k for
# this request. Sparse format: only retained positions, not a dense
# vocab-sized bool mask.
sampling_mask: Optional[np.array] = None
Copy link

Copilot AI Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StreamTransferData.sampling_mask 的类型注解写成了 Optional[np.array],但 np.array 是函数而不是类型;这里应使用 np.ndarray(或更具体的 np.ndarray[np.int32] 等)。否则会误导静态检查/IDE。

Copilot uses AI. Check for mistakes.
PaddlePaddle-bot

This comment was marked as outdated.

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review | 2026-04-20 22:08:55

📋 Review 摘要

PR 概述:为 FastDeploy 新增 Keep Sampling Mask (KSM) 功能,在 top_p/top_k 采样时返回保留的词汇表索引稀疏列表,并支持 logprobs 重归一化。
变更范围:sampler、logprobs、pre_and_post_process、token_processor、serving_chat、protocol、config
影响面 Tag[OP] [APIServer] [Engine]

📝 PR 规范检查

  1. PR 标题中的 [KSM] 不在官方 Tag 列表中,应使用 [Feature]
  2. 本 PR 目标分支为 release/2.6(非 develop),根据规范应在标题前添加 [Cherry-Pick] 标签,并在末尾附上原 PR ID。

标题建议(可直接复制):

  • [Cherry-Pick][Feature] Support keep sampling mask for top_p/top_k candidate set(#原PR_ID)

问题

级别 文件 概述
🔴 Bug pre_and_post_process.py:461 sampling_mask_zmq_client 可能为 None 时直接调用 send_pyobj
🔴 Bug pre_and_post_process.py:710 同上,speculative 路径存在相同风险
🟡 建议 pre_and_post_process.py:386 logprobs 隐式重归一化可能导致已有用户 logprobs 语义变化
🟡 建议 sampler.py:114 _compute_sampling_mask 每步执行全量 argsort,性能开销较大

总体评价

功能实现完整,覆盖了 MTP 和非 MTP 两种路径,测试用例覆盖了流式/非流式和 top_p 对比场景。主要问题是 ZMQ client 缺少空值防护可能导致运行时崩溃,以及 logprobs 重归一化对已有 --enable-logprob 用户是隐式的行为变更,建议显式控制。

# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

sampling_mask_zmq_client.send_pyobj(mask_dict)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug sampling_mask_zmq_client 参数类型为 Optional[ZmqIpcClient](默认 None),但此处未做空值检查直接调用 send_pyobj

当前逻辑仅检查 sampler_output.sampling_mask is not None and model_output.mp_rank == 0,如果配置不一致(例如 enable_keep_sampling_mask=True 但 ZMQ client 因某些原因未初始化),将触发 AttributeError: 'NoneType' object has no attribute 'send_pyobj'

建议修复:

if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:

offset += n
if offset != total_masks:
raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
sampling_mask_zmq_client.send_pyobj(mask_dict)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bugsave_output_normal 中相同的问题:sampling_mask_zmq_client 可能为 None(参数默认值为 None),调用前需添加空值防护。

建议修复:

if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:

)

# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 当同时启用 --enable-logprob--enable-keep-sampling-mask 时,此处会将 logprobs 隐式地基于候选集重归一化。这改变了现有 logprobs 的语义(从全词表 log_softmax 变为截断候选集归一化),可能影响已依赖原始 logprobs 值的下游用户。

建议:

  1. 将重归一化作为独立的可选行为(例如 --logprobs-renormalize 参数),或
  2. 在 API 响应中新增独立字段(如 normalized_logprobs)而非覆盖原 logprobs,或
  3. 至少在文档/启动日志中明确说明 --enable-keep-sampling-mask 会修改 logprobs 输出语义。

top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,
top_k_list: Optional[list] = None,
) -> tuple[List[np.ndarray], np.ndarray]:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 _compute_sampling_mask 在每个 decode step 对 [B, vocab_size] 张量执行 argsort(O(B·V·logV)),对于大词表模型(如 GLM-4 的 151k vocab)开销显著。

top_k_top_p_sampling 内部已经执行了类似的排序逻辑。考虑:

  1. 复用 top_k_top_p_sampling 中已有的排序结果来构建 mask,避免重复排序;
  2. 或将 mask 计算与采样合并为一个 kernel 调用。

此外,tuple[List[np.ndarray], np.ndarray] 使用了 Python 3.9+ 的小写 tuple 语法,如项目需兼容 3.8 应改为 Tuple[List[np.ndarray], np.ndarray]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants